{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "input_collapsed": false
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from bqplot import (\n",
    "    Axis, ColorAxis, LinearScale, DateScale, DateColorScale, OrdinalScale,\n",
    "    OrdinalColorScale, ColorScale, Scatter, Lines, Figure, Tooltip\n",
    ")\n",
    "from ipywidgets import Label, ToggleButtons, VBox"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed before the first random draw so that every series below, including\n",
    "# price_data, is reproduced exactly on Restart & Run All.\n",
    "np.random.seed(0)\n",
    "\n",
    "# Two random walks with negatively correlated increments, offset by 100,\n",
    "# indexed by 150 consecutive dates.\n",
    "price_data = pd.DataFrame(np.cumsum(np.random.randn(150, 2).dot([[1.0, -0.8], [-0.8, 1.0]]), axis=0) + 100,\n",
    "                          columns=['Security 1', 'Security 2'], index=pd.date_range(start='01-01-2007', periods=150))\n",
    "\n",
    "# Generic x / y series used by the marker, interaction and tooltip examples.\n",
    "size = 100\n",
    "x_data = range(size)\n",
    "y_data = np.cumsum(np.random.randn(size) * 100.0)\n",
    "ord_keys = np.array(['A', 'B', 'C', 'D', 'E', 'F'])\n",
    "# Integer class labels in [0, 5) used by the ordinal color-scale example.\n",
    "ordinal_data = np.random.randint(5, size=size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "input_collapsed": false
   },
   "outputs": [],
   "source": [
    "symbols = ['Security 1', 'Security 2']\n",
    "\n",
    "# Dates of every observation, and of every observation except the first\n",
    "# (the latter has the same length as the one-period returns below).\n",
    "dates_all = price_data.index.values\n",
    "dates_all_t = dates_all[1:]\n",
    "\n",
    "# Price levels of each security as flat NumPy arrays.\n",
    "sec1_levels, sec2_levels = (\n",
    "    np.array(price_data[name].values.flatten()) for name in symbols\n",
    ")\n",
    "\n",
    "# One-period log returns of Security 1: log(p[t]) - log(p[t-1]).\n",
    "log_sec1 = np.log(sec1_levels)\n",
    "sec1_returns = np.diff(log_sec1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Basic Scatter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Security 2 price levels plotted against their dates.\n",
    "sc_x, sc_y = DateScale(), LinearScale()\n",
    "\n",
    "scatt = Scatter(\n",
    "    x=dates_all,\n",
    "    y=sec2_levels,\n",
    "    scales={'x': sc_x, 'y': sc_y},\n",
    ")\n",
    "ax_x = Axis(label='Date', scale=sc_x)\n",
    "ax_y = Axis(label='Security 2', scale=sc_y, orientation='vertical', tick_format='0.0f')\n",
    "\n",
    "Figure(marks=[scatt], axes=[ax_x, ax_y])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Changing the marker and adding text to each point of the scatter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw the first ten points as red crosses and pass one entry per point\n",
    "# through `names` (here simply the integers 0-9).\n",
    "sc_x = LinearScale()\n",
    "sc_y = LinearScale()\n",
    "\n",
    "scatt = Scatter(x=x_data[:10], y=y_data[:10], names=np.arange(10),\n",
    "                scales={'x': sc_x, 'y': sc_y}, colors=['red'], marker='cross')\n",
    "ax_x = Axis(scale=sc_x)\n",
    "ax_y = Axis(scale=sc_y, orientation='vertical', tick_format='0.2f')\n",
    "\n",
    "Figure(marks=[scatt], axes=[ax_x, ax_y], padding_x=0.025)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Changing the opacity of each marker"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Assign a list of opacities to the ten-point scatter created in the previous code cell.\n",
    "# NOTE(review): only three values are given for ten points; how bqplot applies\n",
    "# a shorter list (presumably by cycling through it) should be confirmed.\n",
    "scatt.opacities = [0.3, 0.5, 1.]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Representing additional dimensions of data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "input_collapsed": false
   },
   "source": [
    "## Linear Scale for Color Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "input_collapsed": false
   },
   "outputs": [],
   "source": [
    "# Color each (date, Security 2 level) point by the Security 1 log return.\n",
    "# sec1_returns is one element shorter than dates_all / sec2_levels, so the\n",
    "# color array does not cover every point.\n",
    "sc_x, sc_y = DateScale(), LinearScale()\n",
    "sc_c1 = ColorScale()\n",
    "\n",
    "scatter = Scatter(\n",
    "    x=dates_all,\n",
    "    y=sec2_levels,\n",
    "    color=sec1_returns,\n",
    "    stroke='black',\n",
    "    scales={'x': sc_x, 'y': sc_y, 'color': sc_c1},\n",
    ")\n",
    "\n",
    "ax_x = Axis(scale=sc_x, label='Date', label_location='end', num_ticks=10)\n",
    "ax_y = Axis(scale=sc_y, label='Security 2', orientation='vertical', side='left')\n",
    "ax_c = ColorAxis(scale=sc_c1, label='Returns', tick_format='0.2%',\n",
    "                 orientation='vertical', side='right')\n",
    "\n",
    "# Wider right margin than left; presumably to leave room for the vertical color axis.\n",
    "m_chart = dict(top=50, bottom=70, left=50, right=100)\n",
    "\n",
    "Figure(title='Scatter of Security 2 vs Dates', marks=[scatter],\n",
    "       axes=[ax_x, ax_c, ax_y], fig_margin=m_chart)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Changing the default color.\n",
    "## sec1_returns has one fewer value than there are points, so (presumably) the\n",
    "## point left without color data is the one drawn with this default color.\n",
    "scatter.colors = ['blue']  # In this case, the dot with the highest X changes to blue."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Setting the fill to be empty: the stroke is cleared and the fill is turned off\n",
    "scatter.stroke = None\n",
    "scatter.fill = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Setting the fill back and restoring the black stroke\n",
    "scatter.stroke = 'black'\n",
    "scatter.fill = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Changing the color to a different variable: color by the Security 2 level and\n",
    "## update the color axis tick format and label to match\n",
    "scatter.color = sec2_levels\n",
    "ax_c.tick_format = '0.0f'\n",
    "ax_c.label = 'Security 2'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Changing the range of the color scale to the colors blue, green and orange\n",
    "sc_c1.colors = ['blue', 'green', 'orange']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Date Scale for Color Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Security 1 level against Security 2 level, colored by observation date.\n",
    "sc_x, sc_y = LinearScale(), LinearScale()\n",
    "sc_c1 = DateColorScale(scheme='Reds')\n",
    "\n",
    "scatter = Scatter(\n",
    "    x=sec2_levels,\n",
    "    y=sec1_levels,\n",
    "    color=dates_all,\n",
    "    default_size=128,\n",
    "    stroke='black',\n",
    "    scales={'x': sc_x, 'y': sc_y, 'color': sc_c1},\n",
    ")\n",
    "\n",
    "ax_x = Axis(scale=sc_x, label='Security 2')\n",
    "ax_y = Axis(scale=sc_y, label='Security 1 Level', orientation='vertical', side='left')\n",
    "ax_c = ColorAxis(scale=sc_c1, label='Date', num_ticks=5)\n",
    "\n",
    "m_chart = {'top': 50, 'bottom': 80, 'left': 50, 'right': 50}\n",
    "Figure(marks=[scatter], axes=[ax_x, ax_c, ax_y], fig_margin=m_chart)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "input_collapsed": false
   },
   "source": [
    "## Ordinal Scale for Color"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "input_collapsed": false
   },
   "outputs": [],
   "source": [
    "# Repeat the class labels until they cover the series, then trim to exactly one\n",
    "# label per plotted point. A new name is used so that ordinal_data itself is not\n",
    "# overwritten and the cell gives the same result however often it is run.\n",
    "n_points = len(sec1_returns)\n",
    "repeats = int(np.ceil(n_points / len(ordinal_data)))\n",
    "ordinal_colors = np.tile(ordinal_data, repeats)[:n_points]\n",
    "\n",
    "c_ord = OrdinalColorScale(colors=['DodgerBlue', 'SeaGreen', 'Yellow', 'HotPink', 'OrangeRed'])\n",
    "sc_x = LinearScale()\n",
    "sc_y = LinearScale()\n",
    "\n",
    "# Security 1 returns against the Security 2 level of the same day, colored by class.\n",
    "scatter2 = Scatter(x=sec2_levels[1:],\n",
    "                   y=sec1_returns,\n",
    "                   color=ordinal_colors,\n",
    "                   scales={'x': sc_x, 'y': sc_y, 'color': c_ord},\n",
    "                   legend='__no_legend__',\n",
    "                   stroke='black')\n",
    "\n",
    "ax_y = Axis(label='Security 1 Returns', scale=sc_y, orientation='vertical', tick_format='.0%')\n",
    "\n",
    "ax_x = Axis(label='Security 2', scale=sc_x, label_location='end')\n",
    "ax_c = ColorAxis(scale=c_ord, label='Class', side='right', orientation='vertical')\n",
    "\n",
    "m_chart = dict(top=50, bottom=70, left=100, right=100)\n",
    "Figure(axes=[ax_x, ax_y, ax_c], marks=[scatter2], fig_margin=m_chart)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Switch the class color axis to a two-decimal tick format and assign a new\n",
    "# list of five colors to the ordinal color scale.\n",
    "ax_c.tick_format = '0.2f'\n",
    "c_ord.colors = ['blue', 'red', 'green', 'yellow', 'orange']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "input_collapsed": false
   },
   "source": [
    "## Setting size and opacity based on data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "input_collapsed": false
   },
   "outputs": [],
   "source": [
    "sc_x = LinearScale()\n",
    "sc_y = LinearScale()\n",
    "\n",
    "# Scales for the data-driven marker size and opacity.\n",
    "sc_size = LinearScale()\n",
    "sc_opacity = LinearScale()\n",
    "\n",
    "# sec1_returns[i] is the return between observations i and i + 1, so both level\n",
    "# series are sliced with [1:] to give x, y and size the same length and the\n",
    "# same dates.\n",
    "scatter2 = Scatter(x=sec2_levels[1:], y=sec1_levels[1:], size=sec1_returns,\n",
    "                   scales={'x': sc_x, 'y': sc_y, 'size': sc_size, 'opacity': sc_opacity},\n",
    "                   default_size=128, colors=['orangered'], stroke='black')\n",
    "\n",
    "ax_y = Axis(label='Security 1', scale=sc_y, orientation='vertical', side='left')\n",
    "ax_x = Axis(label='Security 2', scale=sc_x)\n",
    "\n",
    "Figure(axes=[ax_x, ax_y], marks=[scatter2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Changing the opacity of the scatter by assigning a list of opacities\n",
    "scatter2.opacities = [0.5, 0.3, 0.1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Resetting the size for the scatter: clears the size data set at construction\n",
    "scatter2.size=None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Resetting the opacity list to fully opaque; the next cell then sets the\n",
    "## opacity data according to the date\n",
    "scatter2.opacities = [1.0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Drive the opacity by date. scatter2's x data is sec2_levels[1:], i.e. it starts\n",
    "# at the second observation, so dates_all_t (dates_all without the first date)\n",
    "# gives exactly one date per point, aligned with x.\n",
    "scatter2.opacity = dates_all_t"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Changing the skew of the marker"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sc_x = LinearScale()\n",
    "sc_y = LinearScale()\n",
    "sc_e = LinearScale()  # scale for the data-driven skew\n",
    "\n",
    "# sec1_returns has one value fewer than the level series, so both levels are\n",
    "# sliced with [1:] to give x, y and skew the same length and the same dates.\n",
    "scatter = Scatter(scales={'x': sc_x, 'y': sc_y, 'skew': sc_e},\n",
    "                  x=sec2_levels[1:], y=sec1_levels[1:],\n",
    "                  skew=sec1_returns, stroke=\"black\",\n",
    "                  colors=['gold'], default_size=200,\n",
    "                  marker='rectangle', default_skew=0)\n",
    "\n",
    "ax_y = Axis(label='Security 1', scale=sc_y, orientation='vertical', side='left')\n",
    "ax_x = Axis(label='Security 2', scale=sc_x)\n",
    "\n",
    "Figure(axes=[ax_x, ax_y], marks=[scatter], animation_duration=1000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clear the skew data on the scatter.\n",
    "scatter.skew = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set the skew data back to the Security 1 returns.\n",
    "scatter.skew = sec1_returns"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Rotation scale"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 30 x 30 grid of arrow markers on [-1, 1] x [-1, 1]: the rotation data is the\n",
    "# squared distance from the origin and the color data is x - y.\n",
    "sc_x, sc_y = LinearScale(), LinearScale()\n",
    "sc_e = LinearScale()\n",
    "sc_c = ColorScale(scheme='Reds')\n",
    "\n",
    "x1 = np.linspace(-1, 1, 30)\n",
    "y1 = np.linspace(-1, 1, 30)\n",
    "x, y = (grid.flatten() for grid in np.meshgrid(x1, y1))\n",
    "rot = x**2 + y**2\n",
    "color = x - y\n",
    "\n",
    "scatter = Scatter(\n",
    "    x=x, y=y,\n",
    "    rotation=rot,\n",
    "    color=color,\n",
    "    marker='arrow',\n",
    "    default_size=200,\n",
    "    default_skew=0.5,\n",
    "    stroke=\"black\",\n",
    "    scales={'x': sc_x, 'y': sc_y, 'color': sc_c, 'rotation': sc_e},\n",
    ")\n",
    "Figure(marks=[scatter], animation_duration=1000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Replace the rotation data with 1 / (x^2 + y^2 + 1), computed on the same grid.\n",
    "scatter.rotation = 1.0 / (x ** 2 + y ** 2 + 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Scatter Chart Interactions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Moving points in Scatter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Enabling moving of points in scatter. Try to click and drag any of the points and\n",
    "## notice the dotted line, drawn at the mean of the y data, update.\n",
    "sc_x, sc_y = LinearScale(), LinearScale()\n",
    "\n",
    "scat = Scatter(x=x_data[:10], y=y_data[:10], colors=['orange'],\n",
    "               scales={'x': sc_x, 'y': sc_y}, enable_move=True)\n",
    "lin = Lines(x=[], y=[], colors=['orange'], line_style='dotted',\n",
    "            scales={'x': sc_x, 'y': sc_y})\n",
    "\n",
    "def update_line(change=None):\n",
    "    \"\"\"Set `lin` to a horizontal segment at mean(scat.y) spanning min..max of scat.x.\"\"\"\n",
    "    y_mean = np.mean(scat.y)\n",
    "    with lin.hold_sync():\n",
    "        lin.x = [np.min(scat.x), np.max(scat.x)]\n",
    "        lin.y = [y_mean, y_mean]\n",
    "\n",
    "# Draw the line once, then redraw it whenever the scatter's x or y changes.\n",
    "update_line()\n",
    "scat.observe(update_line, names=['x', 'y'])\n",
    "\n",
    "ax_x = Axis(scale=sc_x)\n",
    "ax_y = Axis(scale=sc_y, orientation='vertical', tick_format='0.2f')\n",
    "\n",
    "fig = Figure(marks=[scat, lin], axes=[ax_x, ax_y])\n",
    "fig"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Updating X and Y while moving the point"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## In this case on drag, the line updates as you move the points\n",
    "## (rather than only once the point is released).\n",
    "scat.update_on_move = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Label used to show the content of the most recent drag event.\n",
    "# `color` and `font_size` are not Label traits in ipywidgets >= 7; passing them\n",
    "# only raises an unrecognized-argument warning and has no visual effect, so they\n",
    "# are omitted. (ipywidgets 8 supports styling through the `style` argument.)\n",
    "latex_widget = Label()\n",
    "\n",
    "def callback_help(name, value):\n",
    "    \"\"\"Drag-event callback: show the second argument as text in `latex_widget`.\"\"\"\n",
    "    latex_widget.value = str(value)\n",
    "\n",
    "latex_widget"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Register callback_help for the drag-start event of the scatter.\n",
    "scat.on_drag_start(callback_help)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Register callback_help for the drag event as well.\n",
    "scat.on_drag(callback_help)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Register callback_help for the drag-end event.\n",
    "scat.on_drag_end(callback_help)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Restricting movement to only along the Y-axis\n",
    "## (restrict_x is not changed here)\n",
    "scat.restrict_y = True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Adding/Deleting points"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Enabling adding the points to Scatter. Try clicking anywhere on the scatter to add points.\n",
    "## Moving is switched off and the click interaction is set to 'add'; both\n",
    "## assignments are made inside a single hold_sync() block.\n",
    "with scat.hold_sync():\n",
    "    scat.enable_move = False\n",
    "    scat.interactions = {'click': 'add'}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Switching between interactions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scatter attributes to apply for each interaction mode offered by the toggle buttons.\n",
    "INTERACT_PARAMETERS = {\n",
    "    'Add': {'interactions': {'click': 'add'},\n",
    "            'enable_move': False},\n",
    "    'Delete': {'interactions': {'click': 'delete'},\n",
    "               'enable_move': False},\n",
    "    'Drag XY': {'interactions': {'click': None},\n",
    "                'enable_move': True,\n",
    "                'restrict_x': False,\n",
    "                'restrict_y': False},\n",
    "    'Drag X': {'interactions': {'click': None},\n",
    "               'enable_move': True,\n",
    "               'restrict_x': True,\n",
    "               'restrict_y': False},\n",
    "    'Drag Y': {'interactions': {'click': None},\n",
    "               'enable_move': True,\n",
    "               'restrict_x': False,\n",
    "               'restrict_y': True},\n",
    "}\n",
    "\n",
    "# Button labels come from the table keys, so the two cannot drift apart.\n",
    "interact_control = ToggleButtons(options=list(INTERACT_PARAMETERS),\n",
    "                                 style={'button_width': '120px'})\n",
    "\n",
    "def change_interact(change):\n",
    "    \"\"\"Apply the attributes of the newly selected mode to `scat`.\n",
    "\n",
    "    `change` is the traitlets change dict for the buttons' `value` trait;\n",
    "    `change['new']` is the selected label. All attributes are set inside one\n",
    "    hold_sync() block, as in the click-to-add cell earlier in the notebook.\n",
    "    \"\"\"\n",
    "    with scat.hold_sync():\n",
    "        for param, value in INTERACT_PARAMETERS[change['new']].items():\n",
    "            setattr(scat, param, value)\n",
    "\n",
    "interact_control.observe(change_interact, names='value')\n",
    "\n",
    "fig.title = 'Adding/Deleting/Moving points'\n",
    "VBox([fig, interact_control])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Custom event on end of drag"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Whenever drag is ended, there is a custom event dispatched which can be listened to.\n",
    "## Try dragging a point and see the data associated with the event being printed.\n",
    "def test_func(self, content):\n",
    "    \"\"\"Print the content of a drag-end event; the first argument is not used.\"\"\"\n",
    "    print(\"received drag end\", content)\n",
    "\n",
    "scat.on_drag_end(test_func)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "## Adding tooltip and custom hover style"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_sc = LinearScale()\n",
    "y_sc = LinearScale()\n",
    "\n",
    "# Plot only the first 50 points. The slices get their own names so that the\n",
    "# shared x_data / y_data series used by earlier cells are not overwritten.\n",
    "x_tooltip = x_data[:50]\n",
    "y_tooltip = y_data[:50]\n",
    "\n",
    "# Tooltip with the x field (empty format string) and the y field ('.2f').\n",
    "def_tt = Tooltip(fields=['x', 'y'], formats=['', '.2f'])\n",
    "\n",
    "scatter_chart = Scatter(x=x_tooltip, y=y_tooltip, scales={'x': x_sc, 'y': y_sc}, colors=['dodgerblue'],\n",
    "                        tooltip=def_tt, unhovered_style={'opacity': 0.5})\n",
    "ax_x = Axis(scale=x_sc)\n",
    "ax_y = Axis(scale=y_sc, orientation='vertical', tick_format='0.2f')\n",
    "\n",
    "Figure(marks=[scatter_chart], axes=[ax_x, ax_y])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Removing field names from the tooltip\n",
    "def_tt.show_labels = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Changing the fields displayed in the tooltip: y only\n",
    "def_tt.fields = ['y']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Changing the tooltip fields again: x only\n",
    "def_tt.fields = ['x']"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
